import copy

import torch
from easydict import EasyDict
from timm.layers import DropPath
from torch import nn

from Backbones.Transition import transition
from Utils.Registry import Registry
from Utils.Tool import SwapAxes, _init_weights, index_points
from Backbones.Context_Aware_State_Space import statespace, curves
from Backbones.Point_Focused_Attention import point_focused

# Registry for the model classes in this file; classes are added through the
# ``@models.register_module()`` decorator and instantiated with ``models.build``.
models = Registry('PointLearner')


class MovingGuidance(nn.Module):
    """Thin wrapper around a module built from the ``curves`` registry.

    The built object is stored as ``self.serial`` and applied unchanged to the
    coordinates handed to ``forward``.
    """

    def __init__(self, curve_cfgs):
        super().__init__()
        built = curves.build(curve_cfgs)
        self.serial = built

    def forward(self, coordinates):
        """Return the output of the wrapped curve module for ``coordinates``."""
        code = self.serial(coordinates)
        return code


class TokenMixer(nn.Module):
    """Chain a point-focused module and a context-state module.

    Both sub-modules are built from their registries (``point_focused`` and
    ``statespace``). ``forward`` runs the point-focused module first and feeds
    its output, together with ``code``, to the context-state module.

    Args:
        point_focused_cfgs: config passed to ``point_focused.build``. When
            ``glob_indu_cfgs.inducing`` equals 0 its ``NAME`` is temporarily
            replaced by ``'withoutGlobalIndu'`` for the duration of the build.
        context_state_cfgs: config passed to ``statespace.build``.
    """

    def __init__(self, point_focused_cfgs, context_state_cfgs):
        super(TokenMixer, self).__init__()
        ori_type = point_focused_cfgs.NAME
        if point_focused_cfgs.glob_indu_cfgs.inducing == 0:
            point_focused_cfgs.NAME = 'withoutGlobalIndu'

        try:
            self.point_focused = point_focused.build(point_focused_cfgs)
            self.context_state = statespace.build(context_state_cfgs)
        finally:
            # The config object is shared with the caller (and with sibling
            # blocks), so the original NAME must be put back even when one of
            # the builds raises.
            point_focused_cfgs.NAME = ori_type

    def forward(self, coordinates, features, code):
        """Apply the point-focused module, then the context-state module."""
        features = self.point_focused(coordinates, features)
        features = self.context_state(features, code)
        return features


class ChannelMixer(nn.Module):
    """Two-layer pointwise (kernel size 1) convolutional MLP.

    The input passes through ``SwapAxes`` before and after the convolutions.

    Args:
        channels: number of input and output channels.
        mlp_ratio: multiplier for the hidden width.
    """

    def __init__(self, channels, mlp_ratio=0.0):
        super().__init__()
        # NOTE: the ratio is truncated to an int *before* scaling, so a
        # fractional ratio is rounded down and the default of 0.0 yields a
        # hidden width of zero.
        width = channels * int(mlp_ratio)

        layers = [
            SwapAxes(),
            nn.Conv1d(channels, width, 1),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
            nn.Conv1d(width, channels, 1),
            SwapAxes(),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        out = self.mlp(x)
        return out


class PointLearnerBlock(nn.Module):
    """Pre-norm residual block: token mixing followed by channel mixing.

    Each branch is ``LayerNorm -> mixer -> drop_path`` added to its input.

    Args:
        channels: feature width; also written into both config objects as
            ``channels`` before the token mixer is built.
        mlp_ratio: forwarded to ``ChannelMixer``.
        drop_path: rate for ``DropPath``; a value ``<= 0`` disables it.
        point_focused_cfgs: config forwarded to ``TokenMixer``.
        context_state_cfgs: config forwarded to ``TokenMixer``.
    """

    def __init__(self, channels, mlp_ratio, drop_path, point_focused_cfgs, context_state_cfgs):
        super().__init__()

        # Token-mixing branch.
        self.norm1 = nn.LayerNorm(channels)
        for cfg in (point_focused_cfgs, context_state_cfgs):
            cfg.channels = channels
        self.token_mixer = TokenMixer(point_focused_cfgs, context_state_cfgs)

        # Channel-mixing branch.
        self.norm2 = nn.LayerNorm(channels)
        self.channel_mixer = ChannelMixer(channels, mlp_ratio)

        # A single DropPath instance is shared by both residual branches.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, coordinates, features, code):
        """Return ``features`` after both residual branches."""
        residual = features
        mixed = self.token_mixer(coordinates, self.norm1(residual), code)
        after_tokens = self.drop_path(mixed) + residual

        mixed = self.channel_mixer(self.norm2(after_tokens))
        return self.drop_path(mixed) + after_tokens


class TransitionDown(nn.Module):
    """Wrapper around a down-sampling module built from the ``transition`` registry."""

    def __init__(self, trans_down_cfgs):
        super().__init__()
        self.down = transition.build(trans_down_cfgs)

    def forward(self, coordinates, features):
        """Return the ``(center_idx, center, group_input)`` triple from the wrapped module."""
        sampled = self.down(coordinates, features)
        # Unpack explicitly so a result that is not a 3-tuple fails here.
        center_idx, center, group_input = sampled
        return center_idx, center, group_input


class TransitionUp(nn.Module):
    """Wrapper around an up-sampling module built from the ``transition`` registry."""

    def __init__(self, trans_up_cfgs):
        super().__init__()
        self.up = transition.build(trans_up_cfgs)

    def forward(self, xyz1, points1, xyz2, points2):
        """Return the ``(coordinates, features)`` pair from the wrapped module."""
        upsampled = self.up(xyz1, points1, xyz2, points2)
        # Unpack explicitly so a result that is not a pair fails here.
        coordinates, features = upsampled
        return coordinates, features


class Embedding(nn.Module):
    """Input stem: two pointwise (kernel size 1) convolutions with a ReLU between.

    The input passes through ``SwapAxes`` before and after the convolutions.

    Args:
        in_channels: channel count of the raw input features.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        layers = [
            SwapAxes(),
            nn.Conv1d(in_channels, out_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, 1),
            SwapAxes(),
        ]
        self.stem = nn.Sequential(*layers)

    def forward(self, features):
        """Return the embedded ``features``."""
        embedded = self.stem(features)
        return embedded


@models.register_module()
class PointLearnerRec(nn.Module):
    """Encoder-only network with a classification head.

    The encoder has ``len(enc_depths)`` stages. Every stage after the first is
    preceded by a ``TransitionDown``; stage ``i`` holds ``enc_depths[i]``
    ``PointLearnerBlock`` modules of width ``enc_channels[i]``.

    Args:
        in_channels: channel count of the input features (stem input).
        mlp_ratio: forwarded to every ``PointLearnerBlock``.
        drop_path: forwarded to every ``PointLearnerBlock``.
        curve_cfgs: config for ``MovingGuidance``.
        point_focused_cfgs: template config for the blocks. Its
            ``loc_perc_cfgs.neighbours`` is used as the down-sampling
            ``group_size``.
        context_state_cfgs: template config for the blocks.
        trans_down_cfgs: template config for ``TransitionDown``. Its
            ``num_group`` is the initial value and is divided by
            ``stride[i - 1]`` before stage ``i`` is built.
        stride: per-transition divisors for ``num_group``.
        enc_depths: number of blocks per stage.
        enc_channels: feature width per stage.
        num_category: output channels of the classification head.
    """

    def __init__(
            self,
            in_channels,
            mlp_ratio,
            drop_path,
            curve_cfgs,
            point_focused_cfgs,
            context_state_cfgs,
            trans_down_cfgs,
            stride,
            enc_depths,
            enc_channels,
            num_category
    ):
        super(PointLearnerRec, self).__init__()
        self.enc_depths = enc_depths

        # The stage loop below rewrites config fields in place (num_group,
        # channels, inducing, ...). Work on private copies so the caller's
        # configs are left untouched and building the model again from the
        # same configs gives the same architecture.
        point_focused_cfgs = copy.deepcopy(point_focused_cfgs)
        context_state_cfgs = copy.deepcopy(context_state_cfgs)
        trans_down_cfgs = copy.deepcopy(trans_down_cfgs)

        self.stem = Embedding(in_channels, enc_channels[0])
        self.guidance = MovingGuidance(curve_cfgs)

        self.trans_down, self.enc_blocks = nn.ModuleList(), nn.ModuleList()
        for i, depth in enumerate(enc_depths):
            if i > 0:
                trans_down_cfgs.num_group //= stride[i - 1]
                trans_down_cfgs.in_channels = enc_channels[i - 1]
                trans_down_cfgs.out_channels = enc_channels[i]
                trans_down_cfgs.group_size = point_focused_cfgs.loc_perc_cfgs.neighbours
                self.trans_down.append(TransitionDown(trans_down_cfgs))

            # The inducing count follows the current number of groups.
            point_focused_cfgs.glob_indu_cfgs.inducing = trans_down_cfgs.num_group // 8
            blocks = nn.ModuleList([
                PointLearnerBlock(
                    enc_channels[i],
                    mlp_ratio,
                    drop_path,
                    point_focused_cfgs,
                    context_state_cfgs
                )
                for _ in range(depth)
            ])
            self.enc_blocks.append(blocks)

        self.cls_head = nn.Sequential(
            SwapAxes(),
            nn.Conv1d(enc_channels[-1], 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv1d(256, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv1d(64, num_category, 1),
            SwapAxes()
        )

        self.apply(_init_weights)

    def forward(self, coordinates, features):
        """Run the encoder and return the head output averaged over dim 1.

        Presumably ``coordinates`` is ``(batch, points, 3)`` and ``features``
        is ``(batch, points, in_channels)``, making the result
        ``(batch, num_category)`` — TODO confirm against the data loader.
        """
        features = self.stem(features)
        code = self.guidance(coordinates)

        for i, blocks in enumerate(self.enc_blocks):
            if i > 0:
                center_idx, coordinates, features = self.trans_down[i - 1](coordinates, features)
                # Keep only the codes of the points selected as centers.
                code = torch.gather(code, 1, center_idx)

            for block in blocks:
                features = block(coordinates, features, code)

        return self.cls_head(features).mean(dim=1)


@models.register_module()
class PointLearnerPartSeg(nn.Module):
    """Encoder-decoder network with a per-point part-segmentation head.

    The encoder mirrors ``len(enc_depths)`` stages separated by
    ``TransitionDown``; the decoder has ``len(dec_depths)`` stages separated by
    ``TransitionUp``, each up-sampling step consuming the matching encoder
    stage output as a skip connection. A category prompt is concatenated to
    the decoder output before the head.

    Args:
        in_channels: channel count of the input features (stem input).
        mlp_ratio: forwarded to every ``PointLearnerBlock``.
        drop_path: forwarded to every ``PointLearnerBlock``.
        curve_cfgs: config for ``MovingGuidance``.
        point_focused_cfgs: template config for the blocks. Its
            ``loc_perc_cfgs.neighbours`` is used as the down-sampling
            ``group_size``.
        context_state_cfgs: template config for the blocks.
        trans_down_cfgs: template config for ``TransitionDown``. Its
            ``num_group`` is divided by ``stride[i - 1]`` before encoder
            stage ``i`` is built.
        trans_up_cfgs: template config for ``TransitionUp``.
        stride: per-transition divisors for ``num_group``; also used in reverse
            to scale the inducing count back up in the decoder.
        enc_depths: number of blocks per encoder stage.
        enc_channels: feature width per encoder stage.
        dec_depths: number of blocks per decoder stage.
        dec_channels: feature width per decoder stage.
        num_category: width of the category prompt concatenated before the head.
        num_parts: output channels of the segmentation head.

    NOTE(review): ``PointLearnerRec`` applies ``_init_weights`` after
    construction; this class does not — confirm that is intentional.
    """

    def __init__(
            self,
            in_channels,
            mlp_ratio,
            drop_path,
            curve_cfgs,
            point_focused_cfgs,
            context_state_cfgs,
            trans_down_cfgs,
            trans_up_cfgs,
            stride,
            enc_depths,
            enc_channels,
            dec_depths,
            dec_channels,
            num_category,
            num_parts
    ):
        super(PointLearnerPartSeg, self).__init__()

        self.enc_depths = enc_depths
        self.dec_depths = dec_depths

        # The loops below rewrite config fields in place (num_group, inducing,
        # channels, ...). Work on private copies so the caller's configs are
        # left untouched and building the model again from the same configs
        # gives the same architecture.
        point_focused_cfgs = copy.deepcopy(point_focused_cfgs)
        context_state_cfgs = copy.deepcopy(context_state_cfgs)
        trans_down_cfgs = copy.deepcopy(trans_down_cfgs)
        trans_up_cfgs = copy.deepcopy(trans_up_cfgs)

        self.stem = Embedding(in_channels, enc_channels[0])
        self.guidance = MovingGuidance(curve_cfgs)

        self.trans_down, self.enc_blocks = nn.ModuleList(), nn.ModuleList()
        for i, depth in enumerate(enc_depths):
            if i > 0:
                trans_down_cfgs.num_group //= stride[i - 1]
                trans_down_cfgs.in_channels = enc_channels[i - 1]
                trans_down_cfgs.out_channels = enc_channels[i]
                trans_down_cfgs.group_size = point_focused_cfgs.loc_perc_cfgs.neighbours
                self.trans_down.append(TransitionDown(trans_down_cfgs))

            # The inducing count follows the current number of groups.
            point_focused_cfgs.glob_indu_cfgs.inducing = trans_down_cfgs.num_group // 8
            blocks = nn.ModuleList([
                PointLearnerBlock(
                    enc_channels[i],
                    mlp_ratio,
                    drop_path,
                    point_focused_cfgs,
                    context_state_cfgs
                )
                for _ in range(depth)
            ])
            self.enc_blocks.append(blocks)

        self.trans_up, self.dec_blocks = nn.ModuleList(), nn.ModuleList()
        for i, depth in enumerate(dec_depths):
            if i > 0:
                trans_up_cfgs.in_channels_dec = dec_channels[i - 1]
                trans_up_cfgs.in_channels_enc = enc_channels[-i - 1]
                trans_up_cfgs.out_channels = dec_channels[i]
                self.trans_up.append(TransitionUp(trans_up_cfgs))
                # Undo the encoder's stride, walking the stride list backwards.
                point_focused_cfgs.glob_indu_cfgs.inducing *= stride[-i]

            blocks = nn.ModuleList([
                PointLearnerBlock(
                    dec_channels[i],
                    mlp_ratio,
                    drop_path,
                    point_focused_cfgs,
                    context_state_cfgs
                )
                for _ in range(depth)
            ])
            self.dec_blocks.append(blocks)

        self.seg_head = nn.Sequential(
            SwapAxes(),
            nn.Conv1d(dec_channels[-1] + num_category, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv1d(64, num_parts, 1),
            SwapAxes(),
        )

    def forward(self, coordinates, features, cat_prompt):
        """Return the segmentation-head output for every input point.

        ``cat_prompt`` is repeated along dim 1 to match ``features`` and
        concatenated on the last dim, so it is presumably
        ``(batch, num_category)`` — TODO confirm against the data loader.
        """
        features = self.stem(features)
        code = self.guidance(coordinates)

        # Per-stage (coordinates, features, code), reused as decoder skips.
        xyz_features_code = []
        for i, blocks in enumerate(self.enc_blocks):
            if i > 0:
                center_idx, coordinates, features = self.trans_down[i - 1](coordinates, features)
                # Keep only the codes of the points selected as centers.
                code = torch.gather(code, 1, center_idx)

            for block in blocks:
                features = block(coordinates, features, code)

            xyz_features_code.append((coordinates, features, code))

        coordinates, features, code = xyz_features_code[-1]
        for i, blocks in enumerate(self.dec_blocks):
            if i > 0:
                skip_xyz, skip_features, skip_code = xyz_features_code[-i - 1]
                coordinates, features = self.trans_up[i - 1](
                    coordinates,
                    features,
                    skip_xyz,
                    skip_features
                )
                code = skip_code

            for block in blocks:
                features = block(coordinates, features, code)

        features = torch.cat([features, cat_prompt.unsqueeze(dim=1).repeat(1, features.shape[1], 1)], dim=-1)

        return self.seg_head(features)


@models.register_module()
class PointLearnerSemSeg(nn.Module):
    """Encoder-decoder network with a per-point semantic-segmentation head.

    The encoder has ``len(enc_depths)`` stages separated by ``TransitionDown``;
    the decoder has ``len(dec_depths)`` stages separated by ``TransitionUp``,
    each up-sampling step consuming the matching encoder stage output as a
    skip connection.

    Args:
        in_channels: channel count of the input features (stem input).
        mlp_ratio: forwarded to every ``PointLearnerBlock``.
        drop_path: forwarded to every ``PointLearnerBlock``.
        curve_cfgs: config for ``MovingGuidance``.
        point_focused_cfgs: template config for the blocks. Its
            ``loc_perc_cfgs.neighbours`` is used as the down-sampling
            ``group_size``.
        context_state_cfgs: template config for the blocks.
        trans_down_cfgs: template config for ``TransitionDown``; its ``stride``
            is set to ``stride[i - 1]`` before encoder stage ``i`` is built.
        trans_up_cfgs: template config for ``TransitionUp``.
        inducing: per-encoder-stage inducing counts; read in reverse for the
            decoder stages.
        stride: per-transition stride values.
        enc_depths: number of blocks per encoder stage.
        enc_channels: feature width per encoder stage.
        dec_depths: number of blocks per decoder stage.
        dec_channels: feature width per decoder stage.
        num_category: output channels of the segmentation head.

    NOTE(review): ``PointLearnerRec`` applies ``_init_weights`` after
    construction; this class does not — confirm that is intentional.
    """

    def __init__(
            self,
            in_channels,
            mlp_ratio,
            drop_path,
            curve_cfgs,
            point_focused_cfgs,
            context_state_cfgs,
            trans_down_cfgs,
            trans_up_cfgs,
            inducing,
            stride,
            enc_depths,
            enc_channels,
            dec_depths,
            dec_channels,
            num_category
    ):
        super(PointLearnerSemSeg, self).__init__()
        self.enc_depths = enc_depths
        self.dec_depths = dec_depths

        # The loops below write stage-specific fields into the configs
        # (stride, channels, inducing, ...). Work on private copies so none of
        # that state leaks back into the caller's config objects.
        point_focused_cfgs = copy.deepcopy(point_focused_cfgs)
        context_state_cfgs = copy.deepcopy(context_state_cfgs)
        trans_down_cfgs = copy.deepcopy(trans_down_cfgs)
        trans_up_cfgs = copy.deepcopy(trans_up_cfgs)

        self.stem = Embedding(in_channels, enc_channels[0])
        self.guidance = MovingGuidance(curve_cfgs)

        self.trans_down, self.enc_blocks = nn.ModuleList(), nn.ModuleList()
        for i, depth in enumerate(enc_depths):
            if i > 0:
                trans_down_cfgs.stride = stride[i - 1]
                trans_down_cfgs.in_channels = enc_channels[i - 1]
                trans_down_cfgs.out_channels = enc_channels[i]
                trans_down_cfgs.group_size = point_focused_cfgs.loc_perc_cfgs.neighbours
                self.trans_down.append(TransitionDown(trans_down_cfgs))

            point_focused_cfgs.glob_indu_cfgs.inducing = inducing[i]
            blocks = nn.ModuleList([
                PointLearnerBlock(
                    enc_channels[i],
                    mlp_ratio,
                    drop_path,
                    point_focused_cfgs,
                    context_state_cfgs
                )
                for _ in range(depth)
            ])
            self.enc_blocks.append(blocks)

        self.trans_up, self.dec_blocks = nn.ModuleList(), nn.ModuleList()
        for i, depth in enumerate(dec_depths):
            if i > 0:
                trans_up_cfgs.in_channels_dec = dec_channels[i - 1]
                trans_up_cfgs.in_channels_enc = enc_channels[-i - 1]
                trans_up_cfgs.out_channels = dec_channels[i]
                self.trans_up.append(TransitionUp(trans_up_cfgs))
                # Decoder stage i reuses the inducing count of the encoder
                # stage it is paired with.
                point_focused_cfgs.glob_indu_cfgs.inducing = inducing[-i - 1]

            blocks = nn.ModuleList([
                PointLearnerBlock(
                    dec_channels[i],
                    mlp_ratio,
                    drop_path,
                    point_focused_cfgs,
                    context_state_cfgs
                )
                for _ in range(depth)
            ])
            self.dec_blocks.append(blocks)

        self.seg_head = nn.Sequential(
            SwapAxes(),
            nn.Conv1d(dec_channels[-1], 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv1d(64, 32, 1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv1d(32, num_category, 1),
            SwapAxes(),
        )

    def forward(self, coordinates, features):
        """Return the segmentation-head output for every input point."""
        features = self.stem(features)
        code = self.guidance(coordinates)

        # Per-stage (coordinates, features, code), reused as decoder skips.
        xyz_features_code = []
        for i, blocks in enumerate(self.enc_blocks):
            if i > 0:
                center_idx, coordinates, features = self.trans_down[i - 1](coordinates, features)
                # Keep only the codes of the points selected as centers.
                code = torch.gather(code, 1, center_idx)

            for block in blocks:
                features = block(coordinates, features, code)

            xyz_features_code.append((coordinates, features, code))

        coordinates, features, code = xyz_features_code[-1]
        for i, blocks in enumerate(self.dec_blocks):
            if i > 0:
                skip_xyz, skip_features, skip_code = xyz_features_code[-i - 1]
                coordinates, features = self.trans_up[i - 1](
                    coordinates,
                    features,
                    skip_xyz,
                    skip_features
                )
                code = skip_code

            for block in blocks:
                features = block(coordinates, features, code)

        return self.seg_head(features)


if __name__ == '__main__':
    # Smoke test: build PointLearnerRec from an inline config and run one
    # forward pass on random data, printing the output shape.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    cfgs = {
        'NAME': 'PointLearnerRec',
        'in_channels': 6,
        'mlp_ratio': 4,
        'drop_path': 0.3,
        'curve_cfgs': {
            'NAME': 'Hilbert',
            'grid_size': 0.01
        },
        'point_focused_cfgs': {
            'NAME': 'PointFocusedAttention',
            'loc_perc_cfgs': {
                'NAME': 'LocalityPerception',
                'neighbours': 3,
                'loc_pos_cfgs': {
                    'NAME': 'withLocPos'
                }
            },
            'glob_indu_cfgs': {
                'NAME': 'GlobalInducing',
                'glob_pos_cfgs': {
                    'NAME': 'withGlobPos',
                }
            }
        },
        'context_state_cfgs': {
            'NAME': 'ContextAwareStateSpace',
            'ssm_cfgs': {
                'NAME': 'Bidirectional',
                'd_state': 16,
                'expand': 2
            }
        },
        'trans_down_cfgs': {
            'NAME': 'FarthestSampling',
            'num_group': 1024
        },
        'stride': [4, 4, 4],
        'enc_depths': [1, 1, 1, 1],
        'enc_channels': [48, 96, 192, 384],
        # Must match the ``num_category`` parameter of PointLearnerRec
        # (assuming Registry.build forwards config entries as keyword
        # arguments — verify against Utils.Registry).
        'num_category': 40
    }

    cfgs = EasyDict(cfgs)
    model = models.build(cfgs).to(device)

    batch, N, channels = 8, 1024, 6
    coordinates = torch.rand((batch, N, 3)).to(device)
    features = torch.rand((batch, N, channels)).to(device)

    logit = model(coordinates, features)
    print(logit.shape)
